# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for inspect_utils module."""

import abc
import collections
import functools
import types
import textwrap
import unittest

import six

from nvidia.dali._autograph.pyct import inspect_utils
from nvidia.dali._autograph.pyct.testing import basic_definitions
from nvidia.dali._autograph.pyct.testing import decorators


def decorator(f):
    """Return ``f`` untouched; a bare no-op decorator used by the test fixtures."""
    return f


def function_decorator():
    """Build a no-op decorator, to be applied with call syntax: ``@function_decorator()``."""

    def passthrough(f):
        # The decorated callable is handed back as-is, never wrapped.
        return f

    return passthrough


def wrapping_decorator():
    """Build a decorator that swaps its target for a stub returning ``None``.

    The stub is dressed up with ``functools.wraps``, so the returned callable
    carries the metadata of the decorated function (``__name__``, ``__doc__``,
    ``__wrapped__``) even though it never calls it.
    """

    def dec(f):
        def replacement(*_):
            # Stand-in for ``f``: positional arguments are accepted and discarded.
            return None

        def wrapper(*args, **kwargs):
            return replacement(*args, **kwargs)

        # Explicit application of the decorator returned by functools.wraps(f).
        return functools.wraps(f)(wrapper)

    return dec


class TestClass(object):
    """Fixture class exposing one do-nothing method per decoration flavour."""

    def member_function(self):
        """Plain instance method."""

    @decorator
    def decorated_member(self):
        """Instance method behind the bare no-op ``decorator``."""

    @function_decorator()
    def fn_decorated_member(self):
        """Instance method behind the decorator built by ``function_decorator()``."""

    @wrapping_decorator()
    def wrap_decorated_member(self):
        """Instance method replaced by the wrapper from ``wrapping_decorator()``."""

    @staticmethod
    def static_method():
        """Static method."""

    @classmethod
    def class_method(cls):
        """Class method."""


def free_function():
    """Module-level function that does nothing and returns ``None``."""


def factory():
    """Return the module-level ``free_function`` object itself, without calling it."""
    return free_function


def free_factory():
    """Create and return a fresh do-nothing function defined in local scope."""

    def local_function():
        """Do nothing."""

    return local_function


class InspectUtilsTest(unittest.TestCase):
    def test_islambda(self):
        """A lambda expression is reported as a lambda; a ``def`` function is not."""

        def regular_function():
            pass

        is_lambda = inspect_utils.islambda
        self.assertTrue(is_lambda(lambda x: x))
        self.assertFalse(is_lambda(regular_function))

    def test_islambda_renamed_lambda(self):
        """Overwriting a lambda's ``__name__`` does not stop it from being detected."""
        # The object under test has to be created by a lambda expression.
        renamed = lambda x: 1  # noqa: E731
        renamed.__name__ = "f"
        self.assertTrue(inspect_utils.islambda(renamed))

    def test_isnamedtuple(self):
        """A ``collections.namedtuple`` type is detected; a plain tuple subclass is not."""

        class NotANamedTuple(tuple):
            pass

        named_tuple_type = collections.namedtuple("TestNamedTuple", ["a", "b"])

        self.assertTrue(inspect_utils.isnamedtuple(named_tuple_type))
        self.assertFalse(inspect_utils.isnamedtuple(NotANamedTuple))

    def test_isnamedtuple_confounder(self):
        """Pin a known false positive: a tuple subclass that merely defines ``_fields``."""

        class NamedTupleLike(tuple):
            _fields = ("a", "b")

        detected = inspect_utils.isnamedtuple(NamedTupleLike)
        self.assertTrue(detected)

    def test_isnamedtuple_subclass(self):
        """A class derived from a namedtuple type is itself reported as a namedtuple."""
        base_type = collections.namedtuple("Test", ["a", "b"])

        class NamedTupleSubclass(base_type):
            pass

        detected = inspect_utils.isnamedtuple(NamedTupleSubclass)
        self.assertTrue(detected)

    def assertSourceIdentical(self, actual, expected):
        self.assertEqual(textwrap.dedent(actual).strip(), textwrap.dedent(expected).strip())

    def test_getimmediatesource_basic(self):
        """Check that the source of a plain (non-``functools.wraps``) wrapper is returned.

        After decoration the name ``test_fn`` is bound to ``f_wrapper``, and the
        expected text is the source of ``f_wrapper``, not of the decorated function.
        """

        def test_decorator(f):
            def f_wrapper(*args, **kwargs):
                return f(*args, **kwargs)

            return f_wrapper

        # Must stay in sync with f_wrapper above. Absolute indentation does not
        # matter because assertSourceIdentical dedents and strips both sides.
        expected = """
    def f_wrapper(*args, **kwargs):
        return f(*args, **kwargs)
    """

        @test_decorator
        def test_fn(a):
            """Test docstring."""
            return [a]

        self.assertSourceIdentical(inspect_utils.getimmediatesource(test_fn), expected)

    def test_getimmediatesource_noop_decorator(self):
        """Check that a pass-through decorator leaves the function's own source visible.

        ``test_decorator`` returns its argument unchanged, so the expected text is
        the definition of ``test_fn`` itself, decorator line and docstring included.
        """

        def test_decorator(f):
            return f

        # Must stay in sync with the definition of test_fn below. Single-quoted
        # triple quotes are used because the text itself contains a docstring.
        expected = '''
    @test_decorator
    def test_fn(a):
        """Test docstring."""
        return [a]
    '''

        @test_decorator
        def test_fn(a):
            """Test docstring."""
            return [a]

        self.assertSourceIdentical(inspect_utils.getimmediatesource(test_fn), expected)

    def test_getimmediatesource_functools_wrapper(self):
        """Check that a ``functools.wraps`` wrapper's own source is returned.

        The expected text is the definition of ``wrapper`` (including its
        ``@functools.wraps(f)`` line), not the source of the wrapped ``test_fn``.
        """

        def wrapper_decorator(f):
            @functools.wraps(f)
            def wrapper(*args, **kwargs):
                return f(*args, **kwargs)

            return wrapper

        # Must stay in sync with the definition of wrapper above.
        expected = textwrap.dedent(
            """
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        return f(*args, **kwargs)
    """
        )

        @wrapper_decorator
        def test_fn(a):
            """Test docstring."""
            return [a]

        self.assertSourceIdentical(inspect_utils.getimmediatesource(test_fn), expected)

    def test_getimmediatesource_functools_wrapper_different_module(self):
        """Same check as the functools wrapper case, with the decorator imported.

        The expected text presumably mirrors the wrapper defined inside
        ``decorators.wrapping_decorator`` - keep the two in sync.
        """

        @decorators.wrapping_decorator
        def test_fn(a):
            """Test docstring."""
            return [a]

        expected = textwrap.dedent(
            """
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        return f(*args, **kwargs)
    """
        )
        actual = inspect_utils.getimmediatesource(test_fn)

        self.assertSourceIdentical(actual, expected)

    def test_getimmediatesource_normal_decorator_different_module(self):
        """Check the source returned for a function decorated by an imported plain decorator.

        The expected text presumably mirrors the wrapper defined inside
        ``decorators.standalone_decorator`` - keep the two in sync.
        """

        @decorators.standalone_decorator
        def test_fn(a):
            """Test docstring."""
            return [a]

        expected = textwrap.dedent(
            """
    def standalone_wrapper(*args, **kwargs):
        return f(*args, **kwargs)
    """
        )
        actual = inspect_utils.getimmediatesource(test_fn)

        self.assertSourceIdentical(actual, expected)

    def test_getimmediatesource_normal_functional_decorator_different_module(self):
        """Check the source returned when an imported decorator factory is applied.

        The expected text presumably mirrors the wrapper created by the decorator
        that ``decorators.functional_decorator()`` returns - keep the two in sync.
        """

        @decorators.functional_decorator()
        def test_fn(a):
            """Test docstring."""
            return [a]

        expected = textwrap.dedent(
            """
    def functional_wrapper(*args, **kwargs):
        return f(*args, **kwargs)
    """
        )
        actual = inspect_utils.getimmediatesource(test_fn)

        self.assertSourceIdentical(actual, expected)

    def test_getnamespace_globals(self):
        """The namespace of the module-level ``factory`` exposes ``free_function``."""
        namespace = inspect_utils.getnamespace(factory)
        self.assertEqual(namespace["free_function"], free_function)

    def test_getnamespace_closure_with_undefined_var(self):
        """An unbound closure variable is left out of the namespace until it is assigned."""
        # This branch never executes, so `a` is still unbound when test_fn is
        # created and first inspected; the only assignment that actually runs is
        # `a = 2` further down.
        if False:  # pylint:disable=using-constant-test
            a = 1

        def test_fn():
            return a

        ns = inspect_utils.getnamespace(test_fn)
        self.assertNotIn("a", ns)

        # Binding `a` now gives the closure a value; a freshly built namespace
        # is expected to expose it.
        a = 2
        ns = inspect_utils.getnamespace(test_fn)

        self.assertEqual(ns["a"], 2)

    def test_getnamespace_hermetic(self):
        """A closure variable shadows a same-named global without altering the globals.

        The namespace returned for ``test_fn`` must hold the local sentinel under
        ``"free_function"``, while the function's real globals must still hold
        something else under that key.
        """
        # Intentionally hiding the global function to make sure we don't overwrite
        # it in the global namespace.
        free_function = object()  # pylint:disable=redefined-outer-name

        def test_fn():
            return free_function

        ns = inspect_utils.getnamespace(test_fn)
        globs = six.get_function_globals(test_fn)
        # assertIs/assertIsNot report both operands on failure, whereas
        # assertTrue(a is b) can only say "False is not true".
        self.assertIs(ns["free_function"], free_function)
        self.assertIsNot(globs["free_function"], free_function)

    def test_getnamespace_locals(self):
        """Closed-over names show up in the namespace; the function's own locals do not.

        ``local_fn`` closes over a function, a list and an int from this method and
        also defines ``local_var`` in its own body.
        """

        def called_fn():
            return 0

        closed_over_list = []
        closed_over_primitive = 1

        def local_fn():
            closed_over_list.append(1)
            local_var = 1
            return called_fn() + local_var + closed_over_primitive

        ns = inspect_utils.getnamespace(local_fn)
        self.assertEqual(ns["called_fn"], called_fn)
        self.assertEqual(ns["closed_over_list"], closed_over_list)
        self.assertEqual(ns["closed_over_primitive"], closed_over_primitive)
        # assertNotIn prints the offending container on failure, unlike
        # assertTrue(x not in y).
        self.assertNotIn("local_var", ns)

    def test_getqualifiedname(self):
        """Objects resolve to their namespace key, or to a dotted path through a module.

        An object that is not reachable from the namespace (here the
        ``inspect_utils`` module itself) resolves to ``None``.
        """
        foo = object()
        baz = object()
        qux = types.ModuleType("quxmodule")
        bar = types.ModuleType("barmodule")
        bar.baz = baz  # reachable only through the "bar" module

        ns = dict(foo=foo, bar=bar, qux=qux)
        qualified_name = inspect_utils.getqualifiedname

        self.assertIsNone(qualified_name(ns, inspect_utils))
        self.assertEqual(qualified_name(ns, foo), "foo")
        self.assertEqual(qualified_name(ns, bar), "bar")
        self.assertEqual(qualified_name(ns, baz), "bar.baz")

    def test_getqualifiedname_efficiency(self):
        """The search completes on a densely connected module graph with a huge depth limit."""
        foo = object()

        # Ten layers of ten modules each. Only the first layer is listed in the
        # namespace; every module of a layer references every module of the next
        # layer, and the target hides in the last module of the last layer. The
        # path to the target is long and each node has about 10 neighbors, but
        # by skipping visited modules the search should take much less.
        ns = {}
        previous_layer = []
        for layer_index in range(10):
            layer = [
                types.ModuleType("mod_{}_{}".format(layer_index, slot)) for slot in range(10)
            ]
            if layer_index == 9:
                layer[9].foo = foo

            if not previous_layer:
                # First layer: these are the search entry points.
                for module in layer:
                    ns[module.__name__] = module
            for parent in previous_layer:
                for module in layer:
                    parent.__dict__[module.__name__] = module
            previous_layer = layer

        self.assertIsNone(inspect_utils.getqualifiedname(ns, inspect_utils))
        self.assertIsNotNone(inspect_utils.getqualifiedname(ns, foo, max_depth=10000000000))

    def test_getqualifiedname_cycles(self):
        """Circular references between modules do not prevent finding the target."""
        foo = object()

        # A chain of ten modules: only the first is listed in the namespace, each
        # one references its successor, and the last one holds the target. Every
        # module also references all of its predecessors, which creates the
        # cycles the search process should avoid.
        ns = {}
        chain = []
        for index in range(10):
            module = types.ModuleType("mod_{}".format(index))
            if index == 9:
                module.foo = foo

            # Forward edge: from the namespace (first module) or from the predecessor.
            referrer = chain[-1].__dict__ if chain else ns
            referrer[module.__name__] = module

            # Back edges: to every module created before this one.
            for earlier in chain:
                module.__dict__[earlier.__name__] = earlier
            chain.append(module)

        self.assertIsNone(inspect_utils.getqualifiedname(ns, inspect_utils))
        self.assertIsNotNone(inspect_utils.getqualifiedname(ns, foo, max_depth=10000000000))

    def test_getmethodclass(self):
        """Members of the module-level ``TestClass`` resolve to it; free functions to ``None``.

        Every member flavour (plain, decorated, static, class method) is checked
        twice: looked up on the class and looked up on an instance.
        """
        # NOTE(review): keep these calls inline in this method. The upstream
        # TensorFlow helper resolves owners through the caller's frame;
        # presumably the same holds here - confirm before refactoring.

        # Functions without an owning class; assertIsNone states that intent
        # directly instead of comparing against None with assertEqual.
        self.assertIsNone(inspect_utils.getmethodclass(free_function))
        self.assertIsNone(inspect_utils.getmethodclass(free_factory()))

        # Members looked up on the class.
        self.assertEqual(inspect_utils.getmethodclass(TestClass.member_function), TestClass)
        self.assertEqual(inspect_utils.getmethodclass(TestClass.decorated_member), TestClass)
        self.assertEqual(inspect_utils.getmethodclass(TestClass.fn_decorated_member), TestClass)
        self.assertEqual(inspect_utils.getmethodclass(TestClass.wrap_decorated_member), TestClass)
        self.assertEqual(inspect_utils.getmethodclass(TestClass.static_method), TestClass)
        self.assertEqual(inspect_utils.getmethodclass(TestClass.class_method), TestClass)

        # The same members looked up on an instance.
        test_obj = TestClass()
        self.assertEqual(inspect_utils.getmethodclass(test_obj.member_function), TestClass)
        self.assertEqual(inspect_utils.getmethodclass(test_obj.decorated_member), TestClass)
        self.assertEqual(inspect_utils.getmethodclass(test_obj.fn_decorated_member), TestClass)
        self.assertEqual(inspect_utils.getmethodclass(test_obj.wrap_decorated_member), TestClass)
        self.assertEqual(inspect_utils.getmethodclass(test_obj.static_method), TestClass)
        self.assertEqual(inspect_utils.getmethodclass(test_obj.class_method), TestClass)

    def test_getmethodclass_locals(self):
        """Members of a class defined inside the test resolve to that local class.

        A function defined inside the test resolves to ``None``.
        """
        # NOTE(review): keep the statement order and the set of local variables
        # as they are. The upstream TensorFlow helper resolves owners through the
        # caller's frame; presumably the same holds here - confirm before
        # refactoring.

        def local_function():
            pass

        class LocalClass(object):
            def member_function(self):
                pass

            @decorator
            def decorated_member(self):
                pass

            @function_decorator()
            def fn_decorated_member(self):
                pass

            @wrapping_decorator()
            def wrap_decorated_member(self):
                pass

        # assertIsNone states the intent directly instead of comparing against
        # None with assertEqual.
        self.assertIsNone(inspect_utils.getmethodclass(local_function))

        # Members looked up on the class.
        self.assertEqual(inspect_utils.getmethodclass(LocalClass.member_function), LocalClass)
        self.assertEqual(inspect_utils.getmethodclass(LocalClass.decorated_member), LocalClass)
        self.assertEqual(inspect_utils.getmethodclass(LocalClass.fn_decorated_member), LocalClass)
        self.assertEqual(inspect_utils.getmethodclass(LocalClass.wrap_decorated_member), LocalClass)

        # The same members looked up on an instance.
        test_obj = LocalClass()
        self.assertEqual(inspect_utils.getmethodclass(test_obj.member_function), LocalClass)
        self.assertEqual(inspect_utils.getmethodclass(test_obj.decorated_member), LocalClass)
        self.assertEqual(inspect_utils.getmethodclass(test_obj.fn_decorated_member), LocalClass)
        self.assertEqual(inspect_utils.getmethodclass(test_obj.wrap_decorated_member), LocalClass)

    def test_getmethodclass_callables(self):
        """An instance of a class defining ``__call__`` resolves to that class."""

        class TestCallable(object):
            def __call__(self):
                pass

        owner = inspect_utils.getmethodclass(TestCallable())
        self.assertEqual(owner, TestCallable)

    def test_getdefiningclass(self):
        """Each method is attributed to the class in the hierarchy that defines it.

        Overridden and newly added methods belong to ``Subclass``; inherited
        ones (including the class method) belong to ``Superclass``.
        """

        class Superclass(object):
            def foo(self):
                pass

            def bar(self):
                pass

            @classmethod
            def class_method(cls):
                pass

        class Subclass(Superclass):
            def foo(self):
                pass

            def baz(self):
                pass

        # (attribute looked up on Subclass, class expected to define it)
        expected_owners = (
            ("foo", Subclass),
            ("bar", Superclass),
            ("baz", Subclass),
            ("class_method", Superclass),
        )
        for method_name, owner in expected_owners:
            method = getattr(Subclass, method_name)
            self.assertIs(inspect_utils.getdefiningclass(method, Subclass), owner)

    def test_isbuiltin(self):
        """Python built-ins are recognized; a function defined in this module is not."""
        for builtin in (enumerate, eval, float, int, len, range, zip):
            self.assertTrue(inspect_utils.isbuiltin(builtin))

        self.assertFalse(inspect_utils.isbuiltin(function_decorator))

    def test_isconstructor(self):
        """Classes count as constructors unless their metaclass overrides ``__call__``."""

        class OrdinaryClass(object):
            pass

        class OrdinaryCallableClass(object):
            def __call__(self):
                pass

        class Metaclass(type):
            pass

        class CallableMetaclass(type):
            def __call__(cls):
                pass

        constructors = (
            OrdinaryClass,
            OrdinaryCallableClass,
            Metaclass,
            Metaclass("TestClass", (), {}),
            CallableMetaclass,
        )
        for candidate in constructors:
            self.assertTrue(inspect_utils.isconstructor(candidate))

        # A class created by the metaclass that overrides __call__ is the one
        # case expected not to be a constructor.
        custom_call_class = CallableMetaclass("TestClass", (), {})
        self.assertFalse(inspect_utils.isconstructor(custom_call_class))

    def test_isconstructor_abc_callable(self):
        """Classes using ``abc.ABCMeta`` count as constructors even when instances are callable."""

        @six.add_metaclass(abc.ABCMeta)
        class AbcBase(object):
            @abc.abstractmethod
            def __call__(self):
                pass

        class AbcSubclass(AbcBase):
            def __init__(self):
                pass

            def __call__(self):
                pass

        for candidate in (AbcBase, AbcSubclass):
            self.assertTrue(inspect_utils.isconstructor(candidate))

    def test_getfutureimports_functions(self):
        """None of the checked feature names is reported for ``function_with_print``."""
        imps = inspect_utils.getfutureimports(basic_definitions.function_with_print)
        for feature in ("absolute_import", "division", "print_function", "generators"):
            self.assertNotIn(feature, imps)

    def test_getfutureimports_lambdas(self):
        """None of the checked feature names is reported for ``simple_lambda``."""
        imps = inspect_utils.getfutureimports(basic_definitions.simple_lambda)
        for feature in ("absolute_import", "division", "print_function", "generators"):
            self.assertNotIn(feature, imps)

    def test_getfutureimports_methods(self):
        """None of the checked feature names is reported for ``SimpleClass.method_with_print``."""
        imps = inspect_utils.getfutureimports(basic_definitions.SimpleClass.method_with_print)
        for feature in ("absolute_import", "division", "print_function", "generators"):
            self.assertNotIn(feature, imps)
